import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

/**
 * LeetCode 662: computes the maximum width of a binary tree, where the width of a layer is the
 * distance between its leftmost and rightmost non-null nodes, counting the null slots a complete
 * binary tree would have in between.
 */
public class Solution662 {

  /**
   * A tree node paired with its 1-based position within its layer. Positions are assigned as if
   * the layer belonged to a complete binary tree, so absent nodes still occupy slots.
   *
   * <p>Static so that instances do not capture a hidden reference to the enclosing solution.
   */
  static class NodeInfo {
    /** 1-based position of {@link #node} within its layer (despite the name, not a depth). */
    public int layerNo;

    /** The tree node at this position; never null when created by this class. */
    public TreeNode node;

    public NodeInfo(int layerNo, TreeNode node) {
      this.layerNo = layerNo;
      this.node = node;
    }
  }

  /**
   * Computes a child's position in the next layer from its parent's position.
   *
   * @param lastLayerNo the parent's 1-based position within its own layer
   * @param direct {@code 'l'} for the left child; any other value selects the right child
   * @return {@code 2 * lastLayerNo - 1} for the left child, {@code 2 * lastLayerNo} for the right
   */
  private int getLayerNo(int lastLayerNo, char direct) {
    return (lastLayerNo - 1) * 2 + (direct == 'l' ? 1 : 2);
  }

  /**
   * Returns the maximum width over all layers of the tree, using a layer-by-layer traversal.
   *
   * @param root the root of the tree; may be {@code null}
   * @return the maximum layer width, or {@code 0} for an empty tree
   */
  public int widthOfBinaryTree(TreeNode root) {
    if (root == null) {
      return 0;
    }
    int maxWidth = 0;
    List<NodeInfo> layerNodeList = new ArrayList<>();
    layerNodeList.add(new NodeInfo(1, root));
    while (!layerNodeList.isEmpty()) {
      // Nodes are appended left to right, so the first and last entries bound the layer.
      int firstNo = layerNodeList.get(0).layerNo;
      int lastNo = layerNodeList.get(layerNodeList.size() - 1).layerNo;
      maxWidth = Math.max(maxWidth, lastNo - firstNo + 1);

      List<NodeInfo> nextLayerNodeList = new ArrayList<>();
      for (NodeInfo nodeInfo : layerNodeList) {
        // Re-base so the leftmost node of this layer sits at position 1. Without this, positions
        // double every layer and overflow int on deep trees; re-basing keeps them bounded by the
        // layer widths instead of by 2^depth.
        int layerNo = nodeInfo.layerNo - firstNo + 1;
        TreeNode node = nodeInfo.node;
        if (node.left != null) {
          nextLayerNodeList.add(new NodeInfo(getLayerNo(layerNo, 'l'), node.left));
        }
        if (node.right != null) {
          nextLayerNodeList.add(new NodeInfo(getLayerNo(layerNo, 'r'), node.right));
        }
      }
      layerNodeList = nextLayerNodeList;
    }
    return maxWidth;
  }
}
